package onion.mqtt.server.processor;

import io.netty.channel.Channel;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttConnectPayload;
import io.netty.handler.codec.mqtt.MqttConnectReturnCode;
import io.netty.handler.codec.mqtt.MqttMessageBuilders;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.AttributeKey;
import onion.mqtt.server.MqttServerBuilder;
import onion.mqtt.server.MqttServerConfig;
import onion.mqtt.server.MqttServerConst;
import onion.mqtt.server.auth.IMqttServerConnectHandler;
import onion.mqtt.server.event.IMqttServerConnectListener;
import onion.mqtt.server.manager.MessageManager;
import onion.mqtt.server.manager.SessionManager;
import onion.mqtt.server.manager.SubscribeManager;
import onion.mqtt.server.store.MessageStore;
import onion.mqtt.server.store.SessionStore;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Handles MQTT CONNECT packets.
 *
 * <p>For each CONNECT it validates the client identifier and optionally authenticates the client.
 * It closes any connection already registered under the same client id and resets the channel's
 * idle-state handler from the requested keep-alive. It then registers the new session, stores the
 * will message, replies with CONNACK, and finally notifies the optional connect listener
 * asynchronously.
 *
 * @author Mr, Lu
 * @developmentTeam Zhejiang Yunze Information Technology Co., Ltd. (浙江允泽信息科技有限公司)
 * @createTime 2023/12/12
 */
public class ConnectProcessor extends AbstractMqttServerProcessor<MqttConnectMessage> {
    static final Logger log = LoggerFactory.getLogger(ConnectProcessor.class);

    /** Pipeline name of the idle-state handler that is replaced when keep-alive is negotiated. */
    private static final String IDLE_HANDLER_NAME = "idle";

    private final MqttServerConfig config;
    /** Optional authenticator; when {@code null} no credential check is performed. */
    private final IMqttServerConnectHandler connectAuth;
    /** Optional listener told (asynchronously) that a client came online; may be {@code null}. */
    private final IMqttServerConnectListener connectStatusListener;

    public ConnectProcessor(MqttServerBuilder serverBuilder) {
        this.config = serverBuilder.getConfig();
        this.connectAuth = serverBuilder.getConnectHandler();
        this.connectStatusListener = serverBuilder.getConnectListener();
    }

    /**
     * Processes one CONNECT packet received on {@code channel}.
     *
     * @param channel the client connection the packet arrived on
     * @param message the decoded CONNECT packet
     */
    @Override
    public void process(Channel channel, MqttConnectMessage message) {
        MqttConnectPayload payload = message.payload();
        String clientId = payload.clientIdentifier();

        // Deviation from the MQTT spec: a client identifier is mandatory here, regardless of the
        // CleanSession flag.
        if (StringUtils.isBlank(clientId)) {
            sendRefusal(channel, MqttConnectReturnCode.CONNECTION_REFUSED_IDENTIFIER_REJECTED);
            return;
        }

        if (!isAuthenticated(channel, clientId, payload)) {
            sendRefusal(channel, MqttConnectReturnCode.CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD);
            return;
        }

        // Store the client id on the channel so later processing can retrieve it.
        channel.attr(AttributeKey.valueOf(MqttServerConst.CLIENT_ID)).set(clientId);

        boolean cleanSession = message.variableHeader().isCleanSession();
        boolean sessionPresent = replaceExistingSession(clientId, cleanSession);

        int expire = reconfigureIdleHandler(channel, message.variableHeader().keepAliveTimeSeconds());

        // Register the new session.
        SessionStore sessionStore = new SessionStore();
        sessionStore.setClientId(clientId);
        sessionStore.setChannel(channel);
        sessionStore.setExpire(expire);
        sessionStore.setCleanSession(cleanSession);
        SessionManager.getInstance().addSession(sessionStore);

        // Drop any will message left by a previous connection, then store the new one (if any).
        MessageManager.getInstance().removeWillMessageByClient(clientId);
        if (message.variableHeader().isWillFlag()) {
            saveWillMessage(clientId, message);
        }

        // MQTT 3.1.1: Session Present must be 0 for clean sessions and otherwise reflect whether
        // stored session state exists for this client id.
        writeAndFlush(channel, MqttMessageBuilders.connAck()
                .returnCode(MqttConnectReturnCode.CONNECTION_ACCEPTED)
                .sessionPresent(sessionPresent)
                .build());

        notifyOnlineAsync(channel, clientId);
    }

    /** Replies with a refusing CONNACK (Session Present = 0) and closes the channel. */
    private void sendRefusal(Channel channel, MqttConnectReturnCode returnCode) {
        writeAndFlush(channel, MqttMessageBuilders.connAck()
                .returnCode(returnCode)
                .sessionPresent(false)
                .build());
        close(channel);
    }

    /**
     * Runs the configured authenticator, if any.
     *
     * @return {@code true} when no authenticator is configured or it accepts the credentials
     */
    private boolean isAuthenticated(Channel channel, String clientId, MqttConnectPayload payload) {
        if (connectAuth == null) {
            return true;
        }
        byte[] passwordBytes = payload.passwordInBytes();
        // The password is optional in CONNECT, so it may be absent (null). Decode explicitly as
        // UTF-8 instead of the platform default charset.
        String password = passwordBytes == null ? null : new String(passwordBytes, StandardCharsets.UTF_8);
        return connectAuth.verifyAuthenticate(channel, clientId, payload.userName(), password);
    }

    /**
     * Closes any connection already registered for {@code clientId} and decides whether its
     * session state survives.
     *
     * <p>The previous state is discarded when either the old or the new connection uses a clean
     * session, as the MQTT spec requires for a new CleanSession=1 connection.
     *
     * @return {@code true} if stored session state for {@code clientId} was kept (used as the
     *         CONNACK Session Present flag)
     */
    private boolean replaceExistingSession(String clientId, boolean cleanSession) {
        SessionManager sessionManager = SessionManager.getInstance();
        if (!sessionManager.hasSession(clientId)) {
            // No registered session: start with no subscriptions for this client.
            SubscribeManager.getInstance().clearSubscribeByClient(clientId);
            return false;
        }

        SessionStore previous = sessionManager.getSession(clientId);
        if (previous == null) {
            return false;
        }

        // Disconnect the earlier connection using this client id.
        Channel previousChannel = previous.getChannel();
        if (previousChannel != null) {
            previousChannel.close();
        }

        if (previous.isCleanSession() || cleanSession) {
            SubscribeManager.getInstance().clearSubscribeByClient(clientId);
            sessionManager.removeSession(clientId);
            return false;
        }
        return true;
    }

    /**
     * Replaces the pipeline's idle-state handler with one whose reader timeout is 1.5x the
     * client's keep-alive. The 1.5x factor is the grace period the MQTT spec allows.
     *
     * <p>Nothing changes when the keep-alive is 0 or no {@value #IDLE_HANDLER_NAME} handler is
     * installed.
     *
     * @return the reader idle timeout in seconds, or 0 when the handler was not replaced
     */
    private int reconfigureIdleHandler(Channel channel, int keepAliveSeconds) {
        if (keepAliveSeconds <= 0 || !channel.pipeline().names().contains(IDLE_HANDLER_NAME)) {
            return 0;
        }
        int expire = Math.round(keepAliveSeconds * 1.5f);
        channel.pipeline().remove(IDLE_HANDLER_NAME);
        channel.pipeline().addFirst(IDLE_HANDLER_NAME,
                new IdleStateHandler(expire, config.getKeepAlive(), config.getKeepAlive(), TimeUnit.SECONDS));
        return expire;
    }

    /** Builds the will message from the CONNECT packet and registers it for {@code clientId}. */
    private void saveWillMessage(String clientId, MqttConnectMessage message) {
        MqttConnectPayload payload = message.payload();
        MessageStore willMessage = new MessageStore();
        willMessage.setClientId(clientId);
        willMessage.setTopic(payload.willTopic());
        willMessage.setQoS(message.variableHeader().willQos());
        willMessage.setRetain(message.variableHeader().isWillRetain());
        byte[] willPayload = payload.willMessageInBytes();
        if (willPayload != null) {
            willMessage.setPayload(willPayload);
        }
        willMessage.setTimestamp(System.currentTimeMillis());
        willMessage.setNodeId(config.getNodeId());
        MessageManager.getInstance().addWillMessage(willMessage);
    }

    /**
     * Notifies the connect listener (if configured) on the common fork-join pool. Listener
     * failures are logged with their stack trace and never propagate to the caller.
     */
    private void notifyOnlineAsync(Channel channel, String clientId) {
        if (connectStatusListener == null) {
            return;
        }
        CompletableFuture.runAsync(() -> {
            try {
                connectStatusListener.online(channel, clientId);
            } catch (Throwable e) {
                log.error("connect publishEvent error clientId: {}", clientId, e);
            }
        });
    }
}
